[gluon][pa_mqa_logits] memory-safety: mask all OutLogits buffer_store lanes (does NOT fix long-context accuracy)#2936

Draft
maeehart wants to merge 3 commits into
ROCm:mainfrom
maeehart:fix/gluon-pa-mqa-outlogits-store-masks

@maeehart maeehart commented Apr 28, 2026

Summary (memory safety only — see "Scope" below)

Several gl.amd.cdna3.buffer_store(ptr=OutLogits_buffer, …) sites in the Gluon
pa_mqa_logits kernels write float32 logits at logical column indices col
without enforcing the upper bound col < max_model_len. Existing predicates only cover
the lower bound (e.g. col >= 0 or col >= split_context_start). When
context_length == max_model_len and split_context_length is rounded up to
the next KVBlockSize, the inner-loop tail can issue stores up to
(KVBlockSize − 1) + (ChunkKPerStage − 1) columns past the end of the
OutLogits_buffer allocation. That is the textbook unmasked-buffer_store
overshoot pattern.

This patch combines (col < max_model_len) with the existing predicate via
& at every OutLogits_buffer store site in:

  • _gluon_deepgemm_fp8_paged_mqa_logits
  • _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle
  • _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx

So lanes with col >= max_model_len are predicated off and no longer issue a
buffer_store. Nothing else in the kernel body (KV/scale loads, MFMA,
gl.reduce, the tl.where that fills o) is touched, so for any in-bounds
column the value of OutLogits_buffer[batch, col] is bit-identical before and
after this PR.
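
As a minimal sketch of the predicate change, the pure-Python emulation below models one store stage. It is not Gluon code: masked_store and store_masks are hypothetical helper names, the row is padded only so the unpatched path can run without an IndexError, and the numeric values are taken from the first worked example in this description.

```python
def masked_store(buffer, cols, values, mask):
    """Emulate a predicated store: lanes whose mask is False write nothing."""
    for col, value, keep in zip(cols, values, mask):
        if keep:
            buffer[col] = value


def store_masks(context_idx, chunk_k_per_stage, split_context_start, max_model_len):
    """Return lane columns plus the old (lower-bound only) and new masks."""
    cols = [context_idx + lane for lane in range(chunk_k_per_stage)]
    old_mask = [c >= split_context_start for c in cols]
    # The patched predicate: existing lower bound AND the new upper bound.
    new_mask = [lower and (c < max_model_len) for lower, c in zip(old_mask, cols)]
    return cols, old_mask, new_mask


max_model_len = 100
chunk_k_per_stage = 32
cols, old_mask, new_mask = store_masks(75, chunk_k_per_stage, 0, max_model_len)
values = [float(c) for c in cols]  # stand-in logits, illustration only

# Padded rows: indices >= max_model_len stand for memory past the allocation.
row_old = [None] * (max_model_len + chunk_k_per_stage)
row_new = [None] * (max_model_len + chunk_k_per_stage)
masked_store(row_old, cols, values, old_mask)
masked_store(row_new, cols, values, new_mask)

oob_writes_old = sum(v is not None for v in row_old[max_model_len:])
oob_writes_new = sum(v is not None for v in row_new[max_model_len:])
print(oob_writes_old, oob_writes_new)
```

In this emulation the in-bounds part of both rows is identical, which mirrors the claim that only the store predicate changes.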

Scope (read this first)

  • This PR is a memory-safety / OOB-write fix only. It removes
    undefined-behaviour writes that can produce HIP "Memory access fault by GPU"
    (MAF) failures or silently corrupt neighbouring memory.
  • This PR does not claim to change numerical answers for in-bounds
    columns. By construction it cannot — the upstream computation path is
    untouched.
  • Therefore, this PR also does not address the documented long-context
    numerical regression in _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle
    on gfx950 (top-k mismatch vs. deepgemm_fp8_paged_mqa_logits_stage1 once
    context_len exceeds ~2048; see vLLM #39303). That accuracy
    bug is a separate kernel-integration issue (SplitKV tiling, KV-block
    addressing in the LoadBlockIndiceForEachStage branch, MFMA operand
    orientation) and will be addressed in a follow-up.

Worked bounds example

With max_model_len = 100, ChunkKPerStage = 32, context_idx = 75: the store
covers lane columns 75 … 106 (context_idx + 0 … context_idx + 31). Only
75 … 99 are valid; the last 7 lanes (columns 100 … 106) were performing OOB
stores. With this PR they are predicated off.

With ChunkK = 256, ChunkKPerStage = 128, max_model_len = 5120,
context_idx = 5040: indices up to 5167 (5040 + 127) are emitted; the 48 tail
lanes at columns 5120 … 5167 were performing OOB stores and are now predicated off.
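
The lane counts in both examples can be re-derived with a few lines of arithmetic. This is a sketch only; oob_lane_count is a hypothetical helper, not a function in the kernel source.

```python
def oob_lane_count(context_idx, chunk_k_per_stage, max_model_len):
    """Count lanes of one store stage whose column is >= max_model_len."""
    last_col = context_idx + chunk_k_per_stage - 1
    # Lanes past the end, clamped to the range [0, chunk_k_per_stage].
    return max(0, min(chunk_k_per_stage, last_col - max_model_len + 1))


first_example = oob_lane_count(context_idx=75, chunk_k_per_stage=32, max_model_len=100)
second_example = oob_lane_count(context_idx=5040, chunk_k_per_stage=128, max_model_len=5120)
print(first_example, second_example)  # 7 48
```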

Validation

| Gate | Status / Role |
| --- | --- |
| Ruff / Black on pa_mqa_logits.py | ✅ green |
| Repository style/dependency checks | ✅ green |
| Triton MI35X test shards (1 / 2 / 3 / 4 / 5 / 6 / 7) | ✅ green |
| OPUS Tests (MI35X) | ✅ green |
| Standard 1-GPU MI35X shards (0 / 1 / 4) | ✅ green |
| Standard 1-GPU MI35X shard 2 | ⚠ failure (under triage; pre-existing on main per author's bisect) |
| Standard 1-GPU MI300X shards | ⚠ failures (pre-existing on main, not caused by this PR; see CI artifacts) |
| Independent vLLM-side parallel of this fix (over-allocate _cached_paged_logits by +256 cols) | ✅ closed an MAF-class crash that previously reproduced reliably on long-context MTP, 20 / 20 trials clean back-to-back at n_spec=1, c=4 on MI355X |

The vLLM-side variant is the operational equivalent of this PR (it lets the
kernel's overshoot land in safe row padding instead of relying on the kernel
to predicate). Both are valid mitigations for the same memory-safety bug; the
kernel-side fix in this PR removes the dependency on caller-side padding and
is the right layer for the fix.
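
To illustrate why caller-side padding is a weaker guarantee, the sketch below compares a fixed pad width against the worst-case overshoot stated in the Summary, (KVBlockSize − 1) + (ChunkKPerStage − 1). The helper names and the KVBlockSize / ChunkKPerStage values are hypothetical and chosen only for illustration; they are not measured configurations.

```python
def worst_case_overshoot(kv_block_size, chunk_k_per_stage):
    """Worst-case columns written past the row end, per the Summary's bound."""
    return (kv_block_size - 1) + (chunk_k_per_stage - 1)


def padding_absorbs_overshoot(pad_cols, kv_block_size, chunk_k_per_stage):
    """True if pad_cols extra columns cover the worst-case overshoot."""
    return worst_case_overshoot(kv_block_size, chunk_k_per_stage) <= pad_cols


# Hypothetical configuration A: the +256 column pad is sufficient.
config_a_ok = padding_absorbs_overshoot(pad_cols=256, kv_block_size=64, chunk_k_per_stage=128)

# Hypothetical configuration B: a larger block size would exceed the same pad.
config_b_ok = padding_absorbs_overshoot(pad_cols=256, kv_block_size=256, chunk_k_per_stage=128)

print(config_a_ok, config_b_ok)
```

Under these assumed values the pad holds for A and not for B, which is the kind of coupling between caller and kernel parameters that predicating the store in the kernel avoids.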

Patch reproducibility

Regenerate via the supplied diff or cherry-pick from the branch tip; only
buffer_store predicates change.

Checklist

  • Ruff / Black clean on touched file
  • All listed kernels: (col < max_model_len) &-combined with prior
    predicate on every OutLogits_buffer store
  • Confirmed by external (vLLM-side) operational equivalent that the
    OOB write is the cause of the MAF-class crash
  • Upstream ROCm CI fully green on the latest commit before promoting
    from draft to ready-for-review

…le+varctx)

Guard every gl.amd.cdna3.buffer_store targeting OutLogits_buffer with
col < max_model_len, AND-ing with existing split_context_start masks
where present. Prevents SIMD lanes from writing past allocated logits
length on long-context / SplitKV tile boundaries.

See ROCm/aiter community validation: numerical bounds worked example in PR.

Made-with: Cursor
@maeehart maeehart requested review from a team and Copilot April 28, 2026 09:46
@github-actions

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 2936 --add-label <label>

Copilot AI left a comment

Pull request overview

This PR fixes potential out-of-bounds writes in the Gluon pa_mqa_logits preshuffle kernels by ensuring every gl.amd.cdna3.buffer_store into OutLogits_buffer is predicated so SIMD lanes with logical column index col >= max_model_len do not execute the store. This targets the _preshuffle and _preshuffle_varctx kernel variants implicated in HIP memory-access faults on long-context paths.

Changes:

  • Add col < max_model_len upper-bound predicates to multiple OutLogits_buffer buffer_store sites in _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle.
  • Add the same upper-bound store masking in _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx.
  • Where a lower-bound predicate already existed (e.g. >= split_context_start), combine it with the new upper bound via &.

Comment on lines 594 to +598
+ (
context_idx
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
),
mask=context_idx
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
>= split_context_start,
mask=(
…tores

Align _gluon_deepgemm_fp8_paged_mqa_logits buffer_store with logits allocation:
AND (col < max_model_len) onto existing >= 0 predicates (col is index into
OutLogits rows). Same correctness class as preshuffle path; addresses review
asking to cover non-preshuffle variant.

Made-with: Cursor
@maeehart maeehart marked this pull request as draft April 28, 2026 11:56
@maeehart maeehart changed the title [gluon][pa_mqa_logits] Mask all OutLogits buffer_store lanes (preshuffle + varctx) [gluon][pa_mqa_logits] memory-safety: mask all OutLogits buffer_store lanes (does NOT fix long-context accuracy) Apr 28, 2026
@azaidy azaidy requested a review from cagrikymk April 29, 2026 14:08